#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
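Coordinate helpers: convert between latitude/longitude/radius and Cartesian
coordinates, and rotate points on a sphere about an axis through the origin.

Created on Fri Jun 16 10:14:39 2023

@author: jcfq2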
"""


import numpy as np
# from numpy.core.umath_tests import matrix_multiply


def spherical_to_cartesian(latit, longit, radius):
    """Convert latitude/longitude/radius to Cartesian (x, y, z).

    The polar angle is taken as ``latit + 90`` degrees measured from +z.
    This gives z = -radius * sin(latitude), so latitude +90 lies on the -z
    axis. ``cartesian_to_spherical`` uses the matching convention.

    :param latit: latitude(s) in degrees (scalar or NumPy array).
    :param longit: longitude(s) in degrees (scalar or NumPy array).
    :param radius: radial distance(s); broadcasts with the angles.
    :returns: tuple (x, y, z) in the units of ``radius``.
    """
    polar = np.deg2rad(latit + 90)
    azimuth = np.deg2rad(longit)
    sin_polar = np.sin(polar)
    return (
        radius * np.cos(azimuth) * sin_polar,
        radius * np.sin(azimuth) * sin_polar,
        radius * np.cos(polar),
    )

def cartesian_to_spherical(x, y, z):
    """Convert Cartesian (x, y, z) back to (latitude, longitude, radius).

    This is the inverse of ``spherical_to_cartesian`` in this module and uses
    the same convention. The polar angle from +z is shifted by -90 degrees to
    give latitude, so points on the -z axis have latitude +90.

    :param x, y, z: scalars or array-likes with broadcastable shapes.
    :returns: tuple (lat, lon, radius). lat is in degrees within [-90, 90].
        lon is in degrees, with negative values shifted by +360 so it is
        non-negative. radius is in the units of the inputs.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    z = np.asarray(z)

    radius = np.sqrt(x ** 2 + y ** 2 + z ** 2)
    # Polar angle from +z in [0, pi]; arctan2 stays well-conditioned near poles.
    theta = np.arctan2(np.sqrt(x ** 2 + y ** 2), z)
    # arctan2 already resolves the quadrant, so no extra correction is needed.
    phi = np.arctan2(y, x)

    lat = np.rad2deg(theta) - 90
    lon = np.rad2deg(phi)
    # Shift longitudes from (-180, 0) into (180, 360). np.where also handles
    # 0-d (scalar) inputs, unlike boolean-mask assignment.
    lon = np.where(lon < 0, lon + 360, lon)

    return lat, lon, radius

def unit_vector(data, axis=None, out=None):
    """Normalize ``data`` to unit Euclidean length.

    :param data: array-like to normalize.
    :param axis: axis along which to compute norms. With ``None``, a 1-D
        input is normalized as a vector and an N-D input is divided by the
        norm of all its elements.
    :param out: optional destination array. When given, the result is
        written into it (``out`` may be ``data`` itself for in-place use)
        and ``None`` is returned.
    :returns: a new float64 array when ``out`` is None, otherwise ``None``.

    No zero-length check is made; a zero vector yields NaN.
    """
    if out is None:
        data = np.array(data, dtype=np.float64, copy=True)
        if data.ndim == 1:
            # Fast path for a single vector.
            data /= np.sqrt(np.dot(data, data))
            return data
    else:
        if out is not data:
            # np.asarray, not np.array(..., copy=False): under NumPy 2 the
            # latter raises ValueError whenever a copy cannot be avoided
            # (e.g. when ``data`` is a list).
            out[:] = np.asarray(data)
        data = out
    length = np.atleast_1d(np.sum(data * data, axis))
    np.sqrt(length, length)  # in-place square root
    if axis is not None:
        # Restore the reduced axis so the norms broadcast against ``data``.
        length = np.expand_dims(length, axis)
    data /= length
    if out is None:
        return data
    return None

def rotation_matrix(angle, direction, point=None):
    """Return the 4x4 homogeneous matrix for a rotation about an axis.

    The 3x3 rotation is built with Rodrigues' formula
    R = cos(a) I + (1 - cos(a)) d d^T + sin(a) [d]_x, where d is the unit
    axis and [d]_x is its cross-product (skew-symmetric) matrix.

    :param angle: rotation angle in radians.
    :param direction: axis vector. Only its first three components are used,
        and it is normalized internally, so its length does not matter.
    :param point: optional point the axis passes through. When omitted, the
        axis passes through the origin.
    :returns: (4, 4) float64 ndarray. The upper-left 3x3 block is R, and the
        last column holds the translation needed for an off-origin axis.
    """
    sina = np.sin(angle)
    cosa = np.cos(angle)
    direction = unit_vector(direction[:3])
    # Symmetric part of Rodrigues' formula.
    R = np.diag([cosa, cosa, cosa])
    R += np.outer(direction, direction) * (1.0 - cosa)
    direction *= sina
    # Antisymmetric part: cross-product matrix of sin(a) * d.
    R += np.array([[0.0, -direction[2], direction[1]],
                   [direction[2], 0.0, -direction[0]],
                   [-direction[1], direction[0], 0.0]])
    M = np.identity(4)
    M[:3, :3] = R
    if point is not None:
        # Axis does not pass through the origin: translate so that ``point``
        # is a fixed point of the transform (R p + (p - R p) == p).
        # np.asarray rather than np.array(..., copy=False), which raises
        # ValueError under NumPy 2 when a copy is needed (e.g. list input).
        point = np.asarray(point[:3], dtype=np.float64)
        M[:3, 3] = point - np.dot(R, point)
    return M


def rotatePole(lats, lons, radius, angle=90, axis=(1, 0, 0)):
    """
    Rotate latitude/longitude points rigidly about an axis through the origin.

    Points are converted to Cartesian coordinates with
    ``spherical_to_cartesian``, rotated by the 3x3 part of
    ``rotation_matrix``, and converted back with ``cartesian_to_spherical``.

    :param lats, lons: 1-D array-likes of latitudes/longitudes in degrees.
    :param radius: scalar or array broadcastable against ``lats``.
    :param angle: rotation angle in degrees.
    :param axis: 3-component rotation axis, e.g. (1, 0, 0), (0, 1, 0) or
        (0, 0, 1) for the x, y or z axis. It does not need to be normalized.
    :rtype: tuple (lats, lons, radius). lats are in degrees within [-90, 90];
        lons are in degrees with negative values shifted by +360.
    """
    lats = np.asarray(lats)
    lons = np.asarray(lons)
    assert lats.ndim == 1 and lons.ndim == 1
    assert len(axis) == 3

    x, y, z = spherical_to_cartesian(lats, lons, radius)
    xyz = np.asarray([x, y, z]).T  # one (x, y, z) row per point

    alpha = np.deg2rad(angle)
    rot = rotation_matrix(alpha, axis)[:3, :3]

    # Apply rot to every row: (rot @ p) for each p equals xyz @ rot.T.
    xyz_rot = xyz @ rot.T
    lats, lons, radius = cartesian_to_spherical(
        xyz_rot[:, 0], xyz_rot[:, 1], xyz_rot[:, 2]
    )

    return lats, lons, radius

